/* ************************************************************************
 * Copyright 2013 Advanced Micro Devices, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * ************************************************************************/


#ifndef KTEST_PATTERNS_H_
#define KTEST_PATTERNS_H_

/*
 * Source text for the generated C helper loadFile(path).
 *
 * The emitted function opens `path` with fopen(path, "r"), measures it with
 * fseek(SEEK_END)/ftell, rewinds, and reads it into a zero-filled buffer of
 * size + 1 bytes obtained from calloc.  That buffer is returned, so ownership
 * passes to the caller.  NULL is returned if opening, seeking, measuring or
 * allocating fails, or if fread reports that zero bytes were read.
 */
static std::string loadFileCode(
    "char*\n"
    "loadFile(const char* path)\n"
    "{\n"
    "    FILE *f;\n"
    "    long size;\n"
    "    char *text;\n"
    "\n"
    "    f = fopen(path, \"r\");\n"
    "    if (f == NULL) {\n"
    "        return NULL;\n"
    "    }\n"
    "\n"
    "    if (fseek(f, 0, SEEK_END) != 0) {\n"
    "        fclose(f);\n"
    "        return NULL;\n"
    "    }\n"
    "    size = ftell(f);\n"
    "    if (size == -1) {\n"
    "        fclose(f);\n"
    "        return NULL;\n"
    "    }\n"
    "    if (fseek(f, 0, SEEK_SET) != 0) {\n"
    "        fclose(f);\n"
    "        return NULL;\n"
    "    }\n"
    "\n"
    "    text = (char*)calloc(size + 1, 1);\n"
    "    if (text == NULL) {\n"
    "        fclose(f);\n"
    "        return NULL;\n"
    "    }\n"
    "\n"
    "    if (fread(text, 1, size, f) == 0) {\n"
    "        free(text);\n"
    "        fclose(f);\n"
    "        return NULL;\n"
    "    }\n"
    "    fclose(f);\n"
    "    return text;\n"
    "}\n"
);

/*
 * Source text for the generated template randomVector<T>(N, X, incx).
 *
 * The emitted function wraps X in VectorAccessor<T, int>(X, N, incx) and
 * assigns random<T>() to each of its N elements.  VectorAccessor and
 * random<T> are presumably provided elsewhere in the generated program.
 */
static std::string randomVectorCode(
    "template<typename T>\n"
    "void\n"
    "randomVector(\n"
    "    size_t N,\n"
    "    T *X,\n"
    "    int incx)\n"
    "{\n"
    "    size_t n;\n"
    "    VectorAccessor<T, int> x(X, N, incx);\n"
    "\n"
    "    for (n = 0; n < N; n++) {\n"
    "        x[n] = random<T>();\n"
    "    }\n"
    "}\n"
);

/*
 * Source text for the generated template unitVector<T>(N, X, incx).
 *
 * The emitted function wraps X in VectorAccessor<T, int>(X, N, incx) and
 * sets each of its N elements to ONE<T>().
 */
static std::string unitVectorCode(
    "template<typename T>\n"
    "void\n"
    "unitVector(\n"
    "    size_t N,\n"
    "    T *X,\n"
    "    int incx)\n"
    "{\n"
    "    size_t n;\n"
    "    VectorAccessor<T, int> x(X, N, incx);\n"
    "\n"
    "    for (n = 0; n < N; n++) {\n"
    "        x[n] = ONE<T>();\n"
    "    }\n"
    "}\n"
);

/*
 * Source text for the generated template sawtoothVector<T>(N, X, incx).
 *
 * The emitted function writes the sequence ONE, ONE+ONE, ... into the N
 * elements of VectorAccessor<T, int>(X, N, incx): element n receives the
 * value obtained by adding ONE<T>() to ONE<T>() n times.
 */
static std::string sawtoothVectorCode(
    "template<typename T>\n"
    "void\n"
    "sawtoothVector(\n"
    "    size_t N,\n"
    "    T *X,\n"
    "    int incx)\n"
    "{\n"
    "    T v;\n"
    "    size_t n;\n"
    "    VectorAccessor<T, int> x(X, N, incx);\n"
    "\n"
    "    v = ONE<T>();\n"
    "    for (n = 0; n < N; n++) {\n"
    "        x[n] = v;\n"
    "        v = v + ONE<T>();\n"
    "    }\n"
    "}\n"
);

/*
 * Source text for the generated template compareVectors<T>(N, blasVector,
 * naiveVector, incx).
 *
 * The emitted function walks both vectors through VectorAccessor<T, int>
 * with the same N and incx.  It returns false at the first position where the
 * values differ under `!=`, and true otherwise.  Positions where isNAN() holds
 * for both values are skipped, so matching NaNs count as equal.  The
 * comparison is exact; there is no tolerance.
 */
static std::string compareVectorsCode(
    "template<typename T>\n"
    "bool\n"
    "compareVectors(\n"
    "    size_t N,\n"
    "    T *blasVector,\n"
    "    T *naiveVector,\n"
    "    int incx)\n"
    "{\n"
    "    size_t n;\n"
    "    VectorAccessor<T, int> blas(blasVector, N, incx);\n"
    "    VectorAccessor<T, int> naive(naiveVector, N, incx);\n"
    "    T blasVal, naiveVal;\n"
    "\n"
    "    for (n = 0; n < N; n++) {\n"
    "        blasVal = blas[n];\n"
    "        naiveVal = naive[n];\n"
    "        if (isNAN(blasVal) && isNAN(naiveVal)) {\n"
    "            continue;\n"
    "        }\n"
    "        if (blasVal != naiveVal) {\n"
    "            return false;\n"
    "        }\n"
    "    }\n"
    "    return true;\n"
    "}\n"
);

/*
 * Source text for the generated template compareMatrices<T>(order, rows,
 * columns, blasMatrix, naiveMatrix, ld).
 *
 * The emitted function views both matrices through MatrixAccessor<T> with the
 * same order, clblasNoTrans, dimensions and leading dimension ld.  It returns
 * false at the first [r][c] where the values differ under `!=`, and true
 * otherwise.  Elements where isNAN() holds for both values are skipped.  The
 * comparison is exact; there is no tolerance.
 */
static std::string compareMatricesCode(
    "template<typename T>\n"
    "bool\n"
    "compareMatrices(\n"
    "    clblasOrder order,\n"
    "    size_t rows,\n"
    "    size_t columns,\n"
    "    T *blasMatrix,\n"
    "    T *naiveMatrix,\n"
    "    size_t ld)\n"
    "{\n"
    "    size_t r, c;\n"
    "    MatrixAccessor<T> blas(blasMatrix, order, clblasNoTrans, rows, columns, ld);\n"
    "    MatrixAccessor<T> naive(naiveMatrix, order, clblasNoTrans, rows, columns, ld);\n"
    "    T blasVal, naiveVal;\n"
    "\n"
    "    for (r = 0; r < rows; r++) {\n"
    "        for (c = 0; c < columns; c++) {\n"
    "            blasVal = blas[r][c];\n"
    "            naiveVal = naive[r][c];\n"
    "            if (isNAN(blasVal) && isNAN(naiveVal)) {\n"
    "                continue;\n"
    "            }\n"
    "            if (blasVal != naiveVal) {\n"
    "                return false;\n"
    "            }\n"
    "        }\n"
    "    }\n"
    "    return true;\n"
    "}\n"
);

/*
 * Source text (beginning with a blank line) for the generated template
 * randomMatrix<T>(order, rows, columns, A, lda).
 *
 * The emitted function assigns random<T>() to every [r][c] element of
 * MatrixAccessor<T>(A, order, clblasNoTrans, rows, columns, lda).
 */
static std::string randomMatrixCode(
    "\n"
    "template<typename T>\n"
    "void\n"
    "randomMatrix(\n"
    "    clblasOrder order,\n"
    "    size_t rows,\n"
    "    size_t columns,\n"
    "    T *A,\n"
    "    size_t lda)\n"
    "{\n"
    "    size_t r, c;\n"
    "    MatrixAccessor<T> a(A, order, clblasNoTrans, rows, columns, lda);\n"
    "\n"
    "    for (r = 0; r < rows; r++) {\n"
    "        for (c = 0; c < columns; c++) {\n"
    "            a[r][c] = random<T>();\n"
    "        }\n"
    "    }\n"
    "}\n"
);

/*
 * Source text (beginning with a blank line) for the generated template
 * unitMatrix<T>(order, rows, columns, A, lda).
 *
 * The emitted function sets every [r][c] element of
 * MatrixAccessor<T>(A, order, clblasNoTrans, rows, columns, lda) to ONE<T>().
 * This fills the whole matrix, not only the diagonal.
 */
static std::string unitMatrixCode(
    "\n"
    "template<typename T>\n"
    "void\n"
    "unitMatrix(\n"
    "    clblasOrder order,\n"
    "    size_t rows,\n"
    "    size_t columns,\n"
    "    T *A,\n"
    "    size_t lda)\n"
    "{\n"
    "    size_t r, c;\n"
    "    MatrixAccessor<T> a(A, order, clblasNoTrans, rows, columns, lda);\n"
    "\n"
    "    for (r = 0; r < rows; r++) {\n"
    "        for (c = 0; c < columns; c++) {\n"
    "            a[r][c] = ONE<T>();\n"
    "        }\n"
    "    }\n"
    "}\n"
);

/*
 * Source text (beginning with a blank line) for the generated template
 * sawtoothMatrix<T>(order, rows, columns, A, lda).
 *
 * The emitted function fills MatrixAccessor<T>(A, order, clblasNoTrans, rows,
 * columns, lda) so that all elements of a row share one value.  That value
 * starts at ONE<T>() and grows by ONE<T>() every `step` rows, where
 * step = (size_t)sqrt((double)rows), i.e. floor(sqrt(rows)).
 *
 * The argument to sqrt is converted to double explicitly.  With pre-C++11
 * <cmath> implementations, sqrt(size_t) is an ambiguous overload call.  The
 * result is narrowed back to size_t explicitly rather than implicitly.  step
 * is 0 only when rows == 0; the row loop then never runs, so `r % step` is
 * never evaluated with a zero divisor.
 */
static std::string sawtoothMatrixCode =
"\n"
"template<typename T>\n"
"void\n"
"sawtoothMatrix(\n"
"    clblasOrder order,\n"
"    size_t rows,\n"
"    size_t columns,\n"
"    T *A,\n"
"    size_t lda)\n"
"{\n"
"    size_t step;\n"
"    T v;\n"
"    size_t r, c;\n"
"    MatrixAccessor<T> a(A, order, clblasNoTrans, rows, columns, lda);\n"
"\n"
"    step = (size_t)sqrt((double)rows);\n"
"    v = ONE<T>();\n"
"\n"
"    for (r = 0; r < rows; r++) {\n"
"        if ((r != 0) && (r % step == 0)) {\n"
"            v = v + ONE<T>();\n"
"        }\n"
"        for (c = 0; c < columns; c++) {\n"
"            a[r][c] = v;\n"
"        }\n"
"    }\n"
"}\n";

/*
 * Source text (ending with a blank line) for the generated template
 * setUpTRSMDiagonal<T>(order, side, uplo, transA, diag, M, N, alpha, A, lda,
 * B, ldb).
 *
 * When diag == clblasNonUnit, the emitted function views A as a k x k matrix,
 * where k = M for side == clblasLeft and N otherwise.  It sets the k diagonal
 * entries to ONE<T>().  It then repeatedly doubles a randomly chosen diagonal
 * entry (rand() % k), halving UPPER_BOUND<T>() each time until the bound
 * drops below 1.  In every case it finishes with NaiveBlas::trmm(...) on A and
 * B using the given arguments.
 *
 * The diagonal loop uses the same size_t bound k as the accessor.  The random
 * doubling is skipped when k == 0, because `rand() % k` would otherwise divide
 * by zero for an empty M or N.
 */
static std::string setUpTRSMDiagonalCode =
"template<typename T>\n"
"void\n"
"setUpTRSMDiagonal(\n"
"    clblasOrder order,\n"
"    clblasSide side,\n"
"    clblasUplo uplo,\n"
"    clblasTranspose transA,\n"
"    clblasDiag diag,\n"
"    size_t M,\n"
"    size_t N,\n"
"    T alpha,\n"
"    T *A,\n"
"    size_t lda,\n"
"    T *B,\n"
"    size_t ldb)\n"
"{\n"
"    if (diag == clblasNonUnit) {\n"
"        size_t k = (side == clblasLeft) ? M : N;\n"
"        MatrixAccessor<T> a(A, order, clblasNoTrans, k, k, lda);\n"
"\n"
"        for (size_t i = 0; i < k; i++) {\n"
"            a[i][i] = ONE<T>();\n"
"        }\n"
"        if (k > 0) {\n"
"            double ub = UPPER_BOUND<T>();\n"
"            while (ub >= 1) {\n"
"                size_t i = rand() % k;\n"
"                a[i][i] = a[i][i] * TWO<T>();\n"
"                ub /= 2;\n"
"            }\n"
"        }\n"
"    }\n"
"    NaiveBlas::trmm(order, side, uplo, transA, diag, M, N, alpha, A, lda, B, ldb);\n"
"}\n"
"\n";

/*
 * Source text of the prototypes for the generated helpers getPlatform,
 * getDevice, createKernel and printExecTime.  Their definitions are the
 * texts held in getPlatformCode, getDeviceCode, createKernelCode and
 * printTimeCode in this header.
 */
static std::string forwardDeclarationsCode(
    "cl_platform_id getPlatform(const char *name);\n"
    "cl_device_id getDevice(cl_platform_id platform, const char *name);\n"
    "cl_kernel createKernel(const char *source, cl_context context,\n"
    "    const char* options, cl_int *error);\n"
    "void printExecTime(cl_ulong ns);\n"
);

/*
 * Source text for the generated C helper getPlatform(name).
 *
 * The emitted function lists the OpenCL platforms with clGetPlatformIDs.  It
 * returns the first one whose CL_PLATFORM_NAME, read into a 64-byte buffer,
 * strcmp-equals `name`.  NULL is returned if nothing matches or any
 * enumeration/allocation step fails.  The temporary id list is freed before
 * returning.
 */
static std::string getPlatformCode(
    "cl_platform_id\n"
    "getPlatform(const char *name)\n"
    "{\n"
    "    cl_int err;\n"
    "    cl_uint nrPlatforms, i;\n"
    "    cl_platform_id *list, platform;\n"
    "    char platformName[64];\n"
    "\n"
    "    err = clGetPlatformIDs(0, NULL, &nrPlatforms);\n"
    "    if (err != CL_SUCCESS) {\n"
    "        return NULL;\n"
    "    }\n"
    "\n"
    "    list = (cl_platform_id*)calloc(nrPlatforms, sizeof(*list));\n"
    "    if (list == NULL) {\n"
    "        return NULL;\n"
    "    }\n"
    "\n"
    "    err = clGetPlatformIDs(nrPlatforms, list, NULL);\n"
    "    if (err != CL_SUCCESS) {\n"
    "        free(list);\n"
    "        return NULL;\n"
    "    }\n"
    "\n"
    "    platform = NULL;\n"
    "    for (i = 0; i < nrPlatforms; i++) {\n"
    "        err = clGetPlatformInfo(list[i], CL_PLATFORM_NAME,\n"
    "            sizeof(platformName), platformName, NULL);\n"
    "        if ((err == CL_SUCCESS) && (strcmp(platformName, name) == 0)) {\n"
    "            platform = list[i];\n"
    "            break;\n"
    "        }\n"
    "    }\n"
    "\n"
    "    free(list);\n"
    "    return platform;\n"
    "}\n"
);

/*
 * Source text for the generated C helper getDevice(platform, name).
 *
 * The emitted function lists every device of `platform`
 * (CL_DEVICE_TYPE_ALL).  It returns the first one whose CL_DEVICE_NAME, read
 * into a 64-byte buffer, strcmp-equals `name`.  NULL is returned if nothing
 * matches or any enumeration/allocation step fails.  The temporary id list is
 * freed before returning.
 */
static std::string getDeviceCode(
    "cl_device_id\n"
    "getDevice(\n"
    "    cl_platform_id platform,\n"
    "    const char *name)\n"
    "{\n"
    "\n"
    "    cl_int err;\n"
    "    cl_uint nrDevices, i;\n"
    "    cl_device_id *list, device;\n"
    "    char deviceName[64];\n"
    "\n"
    "    err = clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, 0, NULL, &nrDevices);\n"
    "    if (err != CL_SUCCESS) {\n"
    "        return NULL;\n"
    "    }\n"
    "    list = (cl_device_id*)calloc(nrDevices, sizeof(*list));\n"
    "    if (list == NULL) {\n"
    "        return NULL;\n"
    "    }\n"
    "\n"
    "    err = clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, nrDevices, list, NULL);\n"
    "    if (err != CL_SUCCESS) {\n"
    "        free(list);\n"
    "        return NULL;\n"
    "    }\n"
    "\n"
    "    device = NULL;\n"
    "    for (i = 0; i < nrDevices; i++) {\n"
    "        err = clGetDeviceInfo(list[i], CL_DEVICE_NAME,\n"
    "            sizeof(deviceName), deviceName, NULL);\n"
    "        if ((err == CL_SUCCESS) && (strcmp(deviceName, name) == 0)) {\n"
    "            device = list[i];\n"
    "            break;\n"
    "        }\n"
    "    }\n"
    "\n"
    "    free(list);\n"
    "    return device;\n"
    "}\n"
);

static std::string createKernelCode =
"cl_kernel\n"
"createKernel(\n"
"    const char* source,\n"
"    cl_context context,\n"
"    const char* options,\n"
"    cl_int* error)\n"
"{\n"
"\n"
"    cl_int err;\n"
"    cl_device_id device;\n"
"    cl_program program;\n"
"    cl_kernel kernel;\n"
"    size_t logSize;\n"
"    char *log;\n"
"\n"
"    err = clGetContextInfo(context, CL_CONTEXT_DEVICES, sizeof(device), &device, NULL);\n"
"    if (err != CL_SUCCESS) {\n"
"        if (error != NULL) {\n"
"            *error = err;\n"
"        }\n"
"        return NULL;\n"
"    }\n"
"\n"
"    program = clCreateProgramWithSource(context, 1, &source, NULL, error);\n"
"    if (program == NULL) {\n"
"        return NULL;\n"
"    }\n"
"\n"
"    err = clBuildProgram(program, 1, &device, options, NULL, NULL);\n"
"    if (err != CL_SUCCESS) {\n"
"        logSize = 0;\n"
"        clGetProgramBuildInfo(program, device, CL_PROGRAM_BUILD_LOG, 0, NULL, &logSize);\n"
"        log = (char*)calloc(1, logSize + 1);\n"
"        clGetProgramBuildInfo(program, device, CL_PROGRAM_BUILD_LOG, logSize, log, NULL);\n"
"        printf(\"=== Build log ===\\n%s\\n\", log);\n"
"        free(log);\n"
"        clReleaseProgram(program);\n"
"        if (error != NULL) {\n"
"            *error = err;\n"
"        }\n"
"        return NULL;\n"
"    }\n"
"\n"
"    kernel = NULL;\n"
"    err = clCreateKernelsInProgram(program, 1, &kernel, NULL);\n"
"    clReleaseProgram(program);\n"
"    if (error != NULL) {\n"
"        *error = err;\n"
"    }\n"
"    return kernel;\n"
"}\n";

/*
 * Source text for the generated C helper printExecTime(ns).
 *
 * The emitted function prints a kernel execution time given in nanoseconds.
 * The unit depends on the size of ns:
 *   - above 10,000,000 ns: milliseconds (ns / 1000000);
 *   - above 10,000 ns: microseconds (ns / 1000);
 *   - otherwise: nanoseconds.
 * Integer division truncates the scaled values.
 *
 * cl_ulong is a 64-bit unsigned type, but `unsigned long` is only 32 bits on
 * LLP64 targets such as 64-bit Windows.  Passing a cl_ulong to "%lu" is
 * therefore undefined behavior there.  Each value is cast to
 * unsigned long long and printed with the matching "%llu".
 */
static std::string printTimeCode =
"void\n"
"printExecTime(cl_ulong ns)\n"
"{\n"
"    if (ns > 10000000) {\n"
"        printf(\"Kernel execution time: %llu milliseconds\\n\",\n"
"            (unsigned long long)(ns / 1000000));\n"
"    }\n"
"    else if (ns > 10000) {\n"
"        printf(\"Kernel execution time: %llu microseconds\\n\",\n"
"            (unsigned long long)(ns / 1000));\n"
"    }\n"
"    else {\n"
"        printf(\"Kernel execution time: %llu nanoseconds\\n\",\n"
"            (unsigned long long)ns);\n"
"    }\n"
"}\n";

#endif /* KTEST_PATTERNS_H_ */
